# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""An analysis that determines the reach of a function definition.

A function definition is said to reach a statement if that function may exist
(and therefore may be called) when that statement executes.
"""

import gast

from nvidia.dali._autograph.pyct import anno
from nvidia.dali._autograph.pyct import cfg
from nvidia.dali._autograph.pyct import transformer


class Definition:
    """A unique definition of a function.

    Attributes:
      def_node: the object passed to the constructor, stored unchanged.
    """

    def __init__(self, def_node):
        # Plain holder: no copying or validation of the wrapped node.
        self.def_node = def_node


class _NodeState(object):
    """State of the CFG walk for the reaching function definition analysis.

    This is a value type: the binary operators build a new instance and leave
    their operands untouched. Only the strictly necessary operators are
    implemented.

    Attributes:
      value: set, the definitions held by this state. The Analyzer in this
          module stores the defining AST nodes (gast.FunctionDef and
          gast.Lambda) in it.
    """

    # Instances wrap a mutable set and compare by content, so they must not be
    # hashable. Defining __eq__ already implies this; it is spelled out here.
    __hash__ = None

    def __init__(self, init_from=None):
        """Creates a state, optionally seeded from an iterable.

        Args:
          init_from: optional iterable of initial definitions. Any falsy value
              (None, empty container) yields an empty state.
        """
        # Always copy, so that two states never share the underlying set.
        self.value = set(init_from) if init_from else set()

    def __eq__(self, other):
        # Defer to the other operand for foreign types instead of failing with
        # AttributeError on `other.value`.
        if not isinstance(other, _NodeState):
            return NotImplemented
        return self.value == other.value

    def __ne__(self, other):
        if not isinstance(other, _NodeState):
            return NotImplemented
        return self.value != other.value

    def __or__(self, other):
        """Returns a new state holding the union of both states."""
        assert isinstance(other, _NodeState)
        result = _NodeState(self.value)
        result.value.update(other.value)
        return result

    def __add__(self, value):
        """Returns a new state that additionally contains `value`."""
        result = _NodeState(self.value)
        result.value.add(value)
        return result

    def __repr__(self):
        return "NodeState[%s]=%s" % (id(self), repr(self.value))


class Analyzer(cfg.GraphVisitor):
    """CFG visitor computing, per statement, the function definitions reaching it.

    Attributes:
      external_defs: iterable of definitions used to seed the state of the
          graph's entry node, e.g. the ones a function closes over.
    """

    def __init__(self, graph, external_defs):
        super().__init__(graph)
        # Lets the caller communicate that nodes have extra reaching
        # definitions, e.g. those that a function closes over.
        self.external_defs = external_defs

    def init_state(self, _):
        """Returns the initial (empty) state for a node."""
        return _NodeState()

    def visit_node(self, node):
        """Recomputes the in/out states of `node`.

        Args:
          node: the CFG node being visited.
        Returns:
          bool, True if the out state of `node` differs from its previous value.
        """
        previous_out = self.out[node]

        # The entry node is seeded with the external definitions; any other
        # node starts from what it already produced on an earlier visit.
        if node is self.graph.entry:
            incoming = _NodeState(self.external_defs)
        else:
            incoming = previous_out

        # Merge in everything flowing out of the predecessors.
        for predecessor in node.prev:
            incoming = incoming | self.out[predecessor]

        # A function or lambda definition adds itself to what flows out.
        if isinstance(node.ast_node, (gast.Lambda, gast.FunctionDef)):
            outgoing = incoming + node.ast_node
        else:
            outgoing = incoming

        self.in_[node] = incoming
        self.out[node] = outgoing

        return previous_out != outgoing


class TreeAnnotator(transformer.Base):
    """AST visitor that annotates nodes with the function definitions reaching them.

    While walking the tree, the visitor runs the dataflow analysis on each
    function node it enters, seeding it with the definitions that reach the
    function itself so that closures are accounted for. For example:

      def foo():
        def f():
          pass
        def g():
          # `def f` reaches here
    """

    def __init__(self, source_info, graphs):
        super().__init__(source_info)
        self.graphs = graphs
        self.allow_skips = False
        # Analyzer of the function currently being walked; None outside of any.
        self.current_analyzer = None

    def _has_cfg_node(self, ast_node):
        """Tells whether the active analyzer's graph indexes `ast_node`."""
        analyzer = self.current_analyzer
        return analyzer is not None and ast_node in analyzer.graph.index

    def _defined_fns_at(self, ast_node):
        """Returns the definitions flowing into the CFG node of `ast_node`."""
        analyzer = self.current_analyzer
        cfg_node = analyzer.graph.index[ast_node]
        return analyzer.in_[cfg_node].value

    def _proces_function(self, node):
        """Analyzes the graph of a function node, then visits its children."""
        enclosing_analyzer = self.current_analyzer
        subgraph = self.graphs[node]

        # Definitions reaching the function in its enclosing scope are handed
        # to the nested analysis as external definitions.
        outer_defs = self._defined_fns_at(node) if self._has_cfg_node(node) else ()

        analyzer = Analyzer(subgraph, outer_defs)
        analyzer.visit_forward()

        # The children are visited with the nested analyzer active; the
        # enclosing one is put back afterwards.
        self.current_analyzer = analyzer
        node = self.generic_visit(node)
        self.current_analyzer = enclosing_analyzer
        return node

    def visit_FunctionDef(self, node):
        return self._proces_function(node)

    def visit_Lambda(self, node):
        return self._proces_function(node)

    def visit(self, node):
        """Annotates `node` (and its extra loop test, if any), then dispatches."""
        # There may be no active analyzer yet, before entering the top level
        # function.
        if self._has_cfg_node(node):
            anno.setanno(node, anno.Static.DEFINED_FNS_IN, self._defined_fns_at(node))

        # A node may carry an extra loop test node; it receives the same
        # annotation.
        extra_node = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST, default=None)
        if extra_node is not None:
            anno.setanno(
                extra_node, anno.Static.DEFINED_FNS_IN, self._defined_fns_at(extra_node)
            )

        return super().visit(node)


def resolve(node, source_info, graphs):
    """Annotates an AST with the function definitions reaching its nodes.

    Builds a TreeAnnotator from `source_info` and `graphs` and runs it over
    `node`.

    Args:
      node: ast.AST
      source_info: transformer.SourceInfo
      graphs: Dict[ast.FunctionDef, cfg.Graph]
    Returns:
      ast.AST
    """
    return TreeAnnotator(source_info, graphs).visit(node)
